
feat(op): add block-level bitwise (and/or/xor/shl/shr/not), arithmeti…#260

Open
wangqin1723-max wants to merge 1 commit into hw-native-sys:main from wangqin1723-max:addop

Conversation

@wangqin1723-max
Contributor

…c (rem), activation (prelu/lrelu), select, matmul variants (matmul_bias/gemv), and broadcast (row_expand) ops

  • Add bitwise block ops: and/or/xor/shl/shr/not with scalar variants (ands/ors/xors/shls/shrs)
  • Add arithmetic op: rem (remainder) with scalar variant
  • Add activation ops: prelu, lrelu
  • Add select ops: sel, sels
  • Add matmul variants: matmul_bias, gemv, gemv_acc, gemv_bias
  • Add broadcast ops: row_expand and variants (row_expand_add/sub/mul/div)
  • Add full Python DSL wrappers, IR op bindings, parser support, and tests
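The intended element-wise semantics of a few of the ops listed above can be sketched in plain Python. This is an illustration of the mathematical definitions only, not the pypto implementation; in particular, the truncation convention for rem is an assumption:

```python
# Illustrative semantics for some of the new block ops, over flat Python lists.
# These are sketches, not the pypto API; rem's truncation toward zero is assumed.

def rem(lhs, rhs):
    # Element-wise remainder with C-style truncation toward zero (assumed).
    return [l - r * int(l / r) for l, r in zip(lhs, rhs)]

def prelu(x, alpha):
    # PReLU: keep positive values, scale negative values by alpha.
    return [v if v > 0 else alpha * v for v in x]

def lrelu(x, slope=0.01):
    # Leaky ReLU is PReLU with a small fixed negative slope.
    return prelu(x, slope)

def sel(mask, a, b):
    # Select: take a[i] where mask[i] is truthy, else b[i].
    return [ai if m else bi for m, ai, bi in zip(mask, a, b)]
```

For example, sel([1, 0, 1], [10, 20, 30], [1, 2, 3]) yields [10, 2, 30].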

@coderabbitai

coderabbitai bot commented Feb 25, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

This PR expands the block-level operations API by adding 30+ new element-wise, arithmetic, bitwise, and matrix/vector operations across the IR and language layers. Changes span IR registration (C++), Python IR bindings, language-level Tile wrappers, module exports, parser updates, and comprehensive test coverage.

Changes

Cohort / File(s): Summary

• IR Block Operation Registrations
  src/ir/op/block_ops/elementwise.cpp, src/ir/op/block_ops/matmul.cpp, src/ir/op/block_ops/broadcast.cpp, src/ir/op/block_ops/unary.cpp
  Adds 30+ new block operation registrations with type deduction helpers, including bitwise/arithmetic ops (rem, shl, shr, and, or, xor), scalar variants (rems, shls, shrs, ands, ors, xors, maxs, mins), arithmetic carry ops (addc, subc, addsc, subsc), activations (prelu, lrelu, not), selection (sel, sels), and matrix ops (matmul_bias, gemv, gemv_acc, gemv_bias). Introduces new type deduction functions (DeduceBlockOpXorBinaryType, DeduceBlockSelType, DeduceBlockMatMulBiasType, etc.).
• IR Python Bindings
  python/pypto/ir/op/block_ops.py
  Adds 35+ Python call wrapper functions mapping to corresponding IR operations, preserving span semantics and argument unpacking for Expr types.
• Language Tile API Wrappers
  python/pypto/language/op/block_ops.py
  Implements 35+ high-level Tile-based wrapper functions for new block operations, with docstrings and Tile-to-Expr conversions matching the IR bindings.
• Module Public API Exports
  python/pypto/language/__init__.py, python/pypto/language/op/__init__.py
  Expands __all__ lists and import statements to expose 35+ new operations publicly, synchronizing block and tensor promotion paths.
• Parser Recognition
  python/pypto/language/parser/ast_parser.py
  Adds new operation names (row_expand, matmul_bias, gemv, gemv_acc, gemv_bias) to dispatch and validation sets for block-only operation recognition.
• Test Coverage
  tests/ut/ir/operators/test_block_ops.py
  Extends test suite with 80+ new test cases across multiple test classes covering bitwise, arithmetic, activation, selection, matrix/vector, and broadcast operations with shape/dtype validation.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • Hzfengsy

Poem

🐰 Hop along, dear blocks and tiles so bright,
With operations new from left and right!
XOR hops, sels bloom, and rows expand,
Matrix magic spreads across the land! ✨🎯

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

• Docstring Coverage: ⚠️ Warning
  Explanation: Docstring coverage is 66.37%, below the required threshold of 80.00%.
  Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

• Title check: ✅ Passed
  The title clearly and concisely summarizes the main additions: block-level bitwise, arithmetic, activation, select, matmul variant, and broadcast operations. It accurately reflects the core changes in the changeset.
• Description check: ✅ Passed
  The description comprehensively details the specific operations being added across multiple categories (bitwise, arithmetic, activation, select, matmul, broadcast) with clear bullet points explaining each component and implementation scope.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello @wangqin1723-max, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the set of available block-level operations within the system. It introduces a wide array of new functionalities ranging from bitwise and arithmetic operations to specialized matrix multiplication and activation functions. These additions enhance the expressiveness and capability of the intermediate representation and Python DSL, enabling more complex computations directly at the block level. The changes are thoroughly integrated with updated parser support and comprehensive unit tests to ensure correctness and stability.

Highlights

  • New Block-Level Operations: Introduced a comprehensive set of new block-level operations, including bitwise (AND, OR, XOR, SHL, SHR, NOT), arithmetic (remainder), activation functions (PReLU, Leaky ReLU), selection (SEL, SELS), and broadcast (row_expand) operations. Each new operation includes both tile-to-tile and tile-to-scalar variants where applicable.
  • Matrix Multiplication Variants: Added several specialized matrix multiplication operations: matmul_bias (matrix multiplication with bias addition), and General Matrix-Vector (GEMV) variants including gemv, gemv_acc (GEMV with accumulation), and gemv_bias (GEMV with bias addition).
  • Python DSL and IR Bindings: Implemented full Python DSL wrappers and corresponding IR op bindings for all new operations, ensuring they are accessible and usable within the Python framework.
  • Parser Support and Unit Tests: Updated the AST parser to recognize the new block operations and added extensive unit tests to validate the functionality and type deduction for all newly introduced operations.
Changelog
  • python/pypto/ir/op/block_ops.py
    • Added Python DSL wrappers for new block-level operations: remainder (rem, rems), bitwise shifts (shl, shls, shr, shrs), bitwise logical (and_, ands, or_, ors, xor, xors, not_), activation (prelu, lrelu), ternary arithmetic (addc, subc, addsc, subsc), selection (sel, sels), matrix multiplication variants (matmul_bias, gemv, gemv_acc, gemv_bias), and row expansion (row_expand).
    • Included type normalization for scalar arguments in new operations.
  • python/pypto/language/__init__.py
    • Updated imports to expose all newly added block-level operations at the top-level pypto.language namespace.
    • Modified the __all__ list to include the new operation names.
  • python/pypto/language/op/__init__.py
    • Updated imports to expose all newly added block-level operations within the pypto.language.op namespace.
    • Modified the __all__ list to include the new operation names.
  • python/pypto/language/op/block_ops.py
    • Updated the __all__ list to include all new block-level operations.
    • Implemented Python wrapper functions for all new IR block operations, handling Tile unwrapping and re-wrapping.
  • python/pypto/language/parser/ast_parser.py
    • Updated the _parse_block_op method to recognize and parse the new row_expand, matmul_bias, gemv, gemv_acc, and gemv_bias operations from the AST.
  • src/ir/op/block_ops/broadcast.cpp
    • Registered the block.row_expand operation, including its description, arguments, and type deduction logic.
  • src/ir/op/block_ops/elementwise.cpp
    • Registered block.rem and block.rems operations for element-wise remainder.
    • Registered bitwise shift operations: block.shl, block.shls, block.shr, block.shrs.
    • Registered scalar max/min operations: block.maxs, block.mins.
    • Registered bitwise logical operations: block.and, block.ands, block.or, block.ors, block.xor, block.xors.
    • Registered activation functions: block.prelu, block.lrelu.
    • Registered ternary arithmetic operations: block.addc, block.subc, block.addsc, block.subsc.
    • Registered selection operations: block.sel, block.sels.
    • Added type deduction functions for XOR and select operations to handle specific argument requirements (e.g., temporary buffers, mask tiles).
  • src/ir/op/block_ops/matmul.cpp
    • Defined DeduceBlockMatMulBiasType for type deduction of matrix multiplication with bias.
    • Registered block.matmul_bias operation.
    • Registered General Matrix-Vector (GEMV) operations: block.gemv, block.gemv_acc, block.gemv_bias.
  • src/ir/op/block_ops/unary.cpp
    • Registered the block.not operation for element-wise bitwise NOT.
  • tests/ut/ir/operators/test_block_ops.py
    • Imported backend and BackendType for backend configuration in tests.
    • Added backend.reset_for_testing() and backend.set_backend_type(BackendType.CCE) calls to row_max_kernel and row_sum_kernel tests.
    • Added new unit tests for block.row_expand_sub, block.row_expand_div, block.row_expand_mul, and block.row_expand.
    • Added new unit tests for block.matmul_acc, block.matmul_bias, block.gemv, block.gemv_acc, and block.gemv_bias.
    • Created TestBlockBitwiseArithmeticOps class to house new tests for block.rem, block.rems, block.and, block.ands, block.or, block.ors, block.xor, block.xors, block.shl, block.shls, block.maxs, block.mins, block.shr, block.shrs, block.prelu, block.not, block.addc, block.subc, block.addsc, block.subsc, block.lrelu, block.sels, and block.sel.
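The changelog above can be grounded with a rough sketch of what the GEMV variants compute, in plain Python over nested lists. The accumulate and bias conventions shown are assumptions based on the op descriptions, not verified pypto semantics:

```python
def gemv(a, x):
    # y = A @ x for an M x N matrix A (list of rows) and a length-N vector x.
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]

def gemv_acc(a, x, y):
    # GEMV with accumulation: y + A @ x (assumed convention).
    return [yi + gi for yi, gi in zip(y, gemv(a, x))]

def gemv_bias(a, x, bias):
    # GEMV with bias addition: A @ x + bias (assumed convention).
    return [gi + bi for gi, bi in zip(gemv(a, x), bias)]
```

matmul_bias follows the same idea one rank up: a matrix-matrix product followed by a bias addition on the output tile.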
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

• Code Review (/gemini review): Performs a code review for the current pull request in its current state.
• Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
• Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
• Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant number of new block-level operations across various categories, including element-wise remainder (rem, rems), bitwise shifts (shl, shls, shr, shrs), bitwise logical operations (and_, ands, or_, ors, xor, xors, not_), activation functions (prelu, lrelu), ternary arithmetic operations (addc, subc, addsc, subsc), selection operations (sel, sels), and matrix operations (matmul_bias, gemv, gemv_acc, gemv_bias). Additionally, a row_expand operation for broadcasting and maxs/mins for scalar comparisons were added.

These operations are implemented in the Python IR layer (python/pypto/ir/op/block_ops.py), exposed through the pypto.language API, and registered in the C++ IR backend with corresponding type deduction logic. Comprehensive unit tests for all new operations have been added to tests/ut/ir/operators/test_block_ops.py to ensure their functionality and correct IR generation.

Review comments highlight an issue where _normalize_expr in scalar bitwise and remainder operations incorrectly uses DataType.FP32 for int_dtype instead of DataType.INT32, which needs correction. Another comment points out a potentially confusing error message in DeduceBlockOpXorBinaryType due to its generic use for operations beyond XOR, suggesting a more general error message or a dedicated deduction function. Finally, the reviewer noted the high level of code duplication in the new unit tests and suggested refactoring them using parameterized tests or helper methods for better maintainability.

Comment on lines 443 to 447

    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )

Contributor

high

The int_dtype for _normalize_expr is set to DataType.FP32. For a bitwise shift operation, the scalar shift amount should be an integer type, not a floating-point type. Using FP32 could lead to incorrect behavior or type errors during compilation. Please use an integer DataType like DataType.INT32.

This issue also exists for shrs, ands, ors, and xors.

Suggested change
    rhs_expr = (
-       _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+       _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
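The point generalizes beyond this codebase: shift and bitwise operators are integer-only in most languages, so typing the scalar operand as FP32 is the wrong default for these intrinsics. A standalone Python illustration (unrelated to _normalize_expr itself):

```python
# Shift operators reject float operands in Python, mirroring why an
# FP32-typed scalar is the wrong default for shift/bitwise intrinsics.
def try_shift(value, amount):
    try:
        return value << amount
    except TypeError:
        # Float shift amounts are rejected at the language level.
        return None

ok = try_shift(5, 2)     # integer shift amount: works
bad = try_shift(5, 2.0)  # float shift amount: TypeError, returns None
```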

Comment on lines 486 to 490

    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )

Contributor

high

The int_dtype for _normalize_expr is set to DataType.FP32. For a bitwise shift operation, the scalar shift amount should be an integer type. Using FP32 is incorrect. Please use an integer DataType like DataType.INT32.

Suggested change
    rhs_expr = (
-       _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+       _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )

Comment on lines 525 to 529

    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )

Contributor

high

The int_dtype for _normalize_expr is set to DataType.FP32. For a bitwise AND operation, the scalar operand should be an integer type. Using FP32 is incorrect. Please use an integer DataType like DataType.INT32.

Suggested change
    rhs_expr = (
-       _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+       _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )

Comment on lines 564 to 568

    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )

Contributor

high

The int_dtype for _normalize_expr is set to DataType.FP32. For a bitwise OR operation, the scalar operand should be an integer type. Using FP32 is incorrect. Please use an integer DataType like DataType.INT32.

Suggested change
    rhs_expr = (
-       _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+       _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )

Comment on lines 605 to 609

    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )

Contributor

high

The int_dtype for _normalize_expr is set to DataType.FP32. For a bitwise XOR operation, the scalar operand should be an integer type. Using FP32 is incorrect. Please use an integer DataType like DataType.INT32.

Suggested change
    rhs_expr = (
-       _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
+       _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )

Comment on lines 326 to 328

    CHECK(As<TileType>(args[2]->GetType()))
        << "The operator " << op_name << " requires third argument (tmp) to be a TileType, but got "
        << args[2]->GetType()->TypeName();

Contributor

high

The error message in this CHECK is specific to xor-like operations that use a temporary buffer (tmp), but this function DeduceBlockOpXorBinaryType is also used for addc and subc, where the third argument is a tile operand, not a temporary buffer. This can lead to confusing error messages.

Consider renaming this function to something more generic like DeduceBlockOpTernaryType and making the error message more general, or creating a separate deduction function for addc and subc.

A quick fix would be to generalize the error message.

Suggested change
    CHECK(As<TileType>(args[2]->GetType()))
-       << "The operator " << op_name << " requires third argument (tmp) to be a TileType, but got "
+       << "The operator " << op_name << " requires third argument to be a TileType, but got "
        << args[2]->GetType()->TypeName();

Contributor Author

fixed

Comment on lines 142 to 144

// ============================================================================
// Registration Function for Block Row Broadcast Operations
// ============================================================================

Contributor

medium

This comment block appears to be a duplicate of the one above it. Please remove it to improve code readability.

Contributor Author

fixed

Comment on lines +1120 to +1121
class TestBlockBitwiseArithmeticOps:
"""Test suite for newly added block-level bitwise and arithmetic ops (rem, and, or, xor)."""
Contributor

medium

This test class, along with others in this file, contains a lot of duplicated code for testing each new operator. This makes the test suite harder to maintain. Consider refactoring these tests to use a parameterized test function or a helper method to reduce boilerplate. A single test function could iterate over a list of operator configurations (op name, arguments, expected IR string).

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

🧹 Nitpick comments (2)
tests/ut/ir/operators/test_block_ops.py (1)

382-383: Isolate backend overrides to avoid cross-test state leakage.

These lines mutate global backend state to BackendType.CCE but do not restore prior state. This can make later tests order-dependent.

💡 Suggested pytest fixture pattern
+@pytest.fixture
+def cce_backend():
+    backend.reset_for_testing()
+    backend.set_backend_type(BackendType.CCE)
+    yield
+    backend.reset_for_testing()
+
 class TestBlockReductionOps:
@@
-    def test_block_row_max(self):
+    def test_block_row_max(self, cce_backend):
@@
-        backend.reset_for_testing()
-        backend.set_backend_type(BackendType.CCE)
@@
-    def test_block_row_sum(self):
+    def test_block_row_sum(self, cce_backend):
@@
-        backend.reset_for_testing()
-        backend.set_backend_type(BackendType.CCE)

Also applies to: 408-409

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ut/ir/operators/test_block_ops.py` around lines 382 - 383, The test
directly mutates global backend state by calling backend.reset_for_testing() and
backend.set_backend_type(BackendType.CCE) without restoring it; change this to
use a pytest fixture or try/finally that captures the current backend type
(e.g., prev = backend.get_backend_type() or equivalent), calls
backend.reset_for_testing() and backend.set_backend_type(BackendType.CCE) for
the test, then restores the original value with backend.set_backend_type(prev)
in teardown; apply the same pattern to the other occurrence that sets
BackendType.CCE so tests don't leak state across runs.
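The save-and-restore idea in the prompt above is independent of pypto; here is a toy version with a module-level stand-in for the backend (all names are illustrative, not the real pypto API):

```python
# Minimal save/restore pattern for global state in tests, using a toy
# "backend" module stand-in; the real pypto backend API is assumed, not shown.
_state = {"backend": "SIM"}

def get_backend_type():
    return _state["backend"]

def set_backend_type(value):
    _state["backend"] = value

def run_with_backend(value, fn):
    # try/finally guarantees the previous backend is restored even if fn raises.
    prev = get_backend_type()
    set_backend_type(value)
    try:
        return fn()
    finally:
        set_backend_type(prev)

result = run_with_backend("CCE", get_backend_type)
```

A pytest yield fixture achieves the same effect declaratively: the code before yield is setup, the code after it is teardown.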
python/pypto/language/op/block_ops.py (1)

1030-1186: Tighten bitwise/shift scalar APIs to integer-only at the DSL layer.

Line 1030, Line 1063, Line 1097, Line 1131, and Line 1168 currently accept float for bitwise/shift scalar variants. Restricting these to integer-like inputs gives earlier and clearer failures.

♻️ Suggested API tightening pattern
-def ands(lhs: Tile, rhs: int | float | Expr | Scalar) -> Tile:
+def ands(lhs: Tile, rhs: int | Expr | Scalar) -> Tile:
@@
-    rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs
+    if isinstance(rhs, float):
+        raise TypeError("ands rhs must be an integer scalar")
+    rhs_expr = rhs.unwrap() if isinstance(rhs, Scalar) else rhs

Apply the same pattern to ors, xors, shls, and shrs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/language/op/block_ops.py` around lines 1030 - 1186, The
bitwise/shift scalar helper functions (ands, ors, xors, shls, shrs) currently
allow floats in their type signatures; tighten them to integer-only at the DSL
layer by removing float from the union types (change parameters from int | float
| Expr | Scalar to int | Expr | Scalar), update the corresponding docstring
"Scalar value" / "Scalar shift amount" notes as needed, and ensure any calls to
rhs.unwrap() logic remains unchanged; modify the function signatures for ands,
ors, xors, shls, and shrs accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/pypto/ir/op/block_ops.py`:
- Around line 425-611: The scalar normalization in shls, shrs, ands, ors, xors
is incorrectly passing int_dtype=DataType.FP32 which will coerce integer
literals to FP32; update the _normalize_expr calls inside the functions shls,
shrs, ands, ors, and xors to use int_dtype=DataType.INT32 (keep
float_dtype=DataType.FP32 or remove float_dtype if not needed) so scalar
literals remain integer-typed for bitwise/shift intrinsics.

In `@src/ir/op/block_ops/broadcast.cpp`:
- Around line 146-159: The type-deduction lambda for
REGISTER_OP("block.row_expand") currently checks only that the argument is a
TileType but not its rank; update the f_deduce_type lambda (the block using the
tile_type variable) to validate that tile_type->shape_.size() == 2 and emit a
CHECK/failure message that the input must be 2D [M, N] (include the actual rank
in the error string). Return the same
std::make_shared<TileType>(tile_type->shape_, tile_type->dtype_) only after this
rank check passes.

In `@src/ir/op/block_ops/elementwise.cpp`:
- Around line 214-313: The bitwise ops (REGISTER_OP entries for "block.shl",
"block.shls", "block.shr", "block.shrs", "block.and", "block.ands", "block.or",
"block.ors") currently call generic deduction helpers and thus allow float
dtypes; change their .f_deduce_type lambdas to call bitwise-specific deduction
helpers (e.g., replace calls to DeduceBlockOpElementwiseBinaryType and
DeduceBlockOpScalarBinaryType with names like
DeduceBlockOpElementwiseBitwiseBinaryType and
DeduceBlockOpScalarBitwiseBinaryType or an equivalent helper that enforces
integer/bitwise-compatible dtypes), and implement/update those helper functions
to validate/require integer or bitwise-compatible ScalarType/TileType during
type inference. Ensure the op-name string passed (e.g., "block.shl",
"block.ands", etc.) is preserved when invoking the new deducer so error messages
remain clear.
- Around line 391-433: The registered ops block.addc, block.subc, block.addsc
and block.subsc currently call DeduceBlockOpXorBinaryType /
DeduceBlockOpXorScalarType which ignore the third operand rhs2; create new
ternary deduction helpers (e.g. DeduceBlockOpTernaryBinaryType and
DeduceBlockOpTernaryScalarType) and update the f_deduce_type lambdas for
"block.addc", "block.subc", "block.addsc", and "block.subsc" to call them;
ensure the new helpers include rhs2 in type validation, broadcast compatibility
checks, and dtype promotion logic (use the same validation points as
DeduceBlockOpXor* but include args[2]/rhs2 in shape and dtype resolution and
error messages).
- Around line 338-356: DeduceBlockOpXorScalarType currently never validates the
third argument (args[2]); add a check that args[2] is a TileType and has the
same shape as the first tile: obtain auto other_tile =
As<TileType>(args[2]->GetType()), CHECK(other_tile) with an error message
referencing op_name, then CHECK(other_tile->shape_ == tile_type->shape_) (and
optionally that other_tile->dtype_ is compatible with tile_type->dtype_ or with
result_dtype if required) before returning the resulting TileType; this ensures
the third operand is a tile with a matching shape (and compatible dtype) in
DeduceBlockOpXorScalarType.

In `@src/ir/op/block_ops/matmul.cpp`:
- Around line 174-215: DeduceBlockMatMulBiasType currently ignores bias shape
and dtype when inferring output: after computing output_shape and result_dtype
from lhs_type and rhs_type, validate that bias_type->shape_ matches output_shape
(same rank and same ConstInt dimensions where present) and include
bias_type->dtype_ in the dtype promotion (e.g., promote lhs and rhs then promote
with bias, or call a 3-way PromoteDataTypes if available) before constructing
and returning the TileType; update checks to use bias_type and output_shape
variables and ensure error messages reference op_name and the offending
dimension or dtype.

In `@src/ir/op/block_ops/unary.cpp`:
- Around line 173-180: The registered op block.not currently uses
DeduceBlockUnaryType which allows floating-point tiles; change the
type-deduction lambda to enforce bitwise-compatible dtypes by checking the input
tile's dtype and rejecting non-integer/non-bool types (or call a helper like
DeduceBlockUnaryBitwiseType if available). Specifically, inside the
REGISTER_OP("block.not") f_deduce_type lambda, inspect the first arg's dtype via
the existing type-deduction utilities and return an error/invalid type when the
dtype is floating-point, ensuring only integer or boolean tile dtypes are
accepted; keep the operator name "block.not" and reuse existing deduction
patterns from other bitwise ops for consistency.
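The kind of shape bookkeeping these deduction fixes call for can be sketched outside C++ as well. The following plain-Python check encodes the [M, K] x [K, N] product with an [M, N] bias, which is the rule the matmul.cpp item above asks DeduceBlockMatMulBiasType to enforce (function name and tuple encoding are illustrative):

```python
def deduce_matmul_bias_shape(lhs, rhs, bias):
    # lhs: (M, K), rhs: (K, N); bias must match the (M, N) output (assumed rule).
    m, k = lhs
    k2, n = rhs
    if k != k2:
        raise ValueError(f"inner dims mismatch: {k} vs {k2}")
    if tuple(bias) != (m, n):
        raise ValueError(f"bias shape {bias} must match output ({m}, {n})")
    return (m, n)
```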

---

Nitpick comments:
In `@python/pypto/language/op/block_ops.py`:
- Around line 1030-1186: The bitwise/shift scalar helper functions (ands, ors,
xors, shls, shrs) currently allow floats in their type signatures; tighten them
to integer-only at the DSL layer by removing float from the union types (change
parameters from int | float | Expr | Scalar to int | Expr | Scalar), update the
corresponding docstring "Scalar value" / "Scalar shift amount" notes as needed,
and ensure any calls to rhs.unwrap() logic remains unchanged; modify the
function signatures for ands, ors, xors, shls, and shrs accordingly.

In `@tests/ut/ir/operators/test_block_ops.py`:
- Around line 382-383: The test directly mutates global backend state by calling
backend.reset_for_testing() and backend.set_backend_type(BackendType.CCE)
without restoring it; change this to use a pytest fixture or try/finally that
captures the current backend type (e.g., prev = backend.get_backend_type() or
equivalent), calls backend.reset_for_testing() and
backend.set_backend_type(BackendType.CCE) for the test, then restores the
original value with backend.set_backend_type(prev) in teardown; apply the same
pattern to the other occurrence that sets BackendType.CCE so tests don't leak
state across runs.

ℹ️ Review info

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 532f7d4 and 60a1c5a.

📒 Files selected for processing (10)
  • python/pypto/ir/op/block_ops.py
  • python/pypto/language/__init__.py
  • python/pypto/language/op/__init__.py
  • python/pypto/language/op/block_ops.py
  • python/pypto/language/parser/ast_parser.py
  • src/ir/op/block_ops/broadcast.cpp
  • src/ir/op/block_ops/elementwise.cpp
  • src/ir/op/block_ops/matmul.cpp
  • src/ir/op/block_ops/unary.cpp
  • tests/ut/ir/operators/test_block_ops.py

Comment on lines 425 to 611
def shls(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise left shift of tile and scalar.

    Computes lhs << rhs element-wise. Maps to the TSHLS hardware intrinsic.

    Note:
        The scalar shift amount must be zero or positive; negative values are
        not supported by the hardware and will be rejected by codegen.

    Args:
        lhs: Tile (TileType)
        rhs: Scalar shift amount (int/float/Expr with ScalarType); must be >= 0
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise left shift with scalar
    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
    return _ir_core.create_op_call("block.shls", [lhs, rhs_expr], {}, actual_span)


def shr(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise right shift of two tiles.

    Computes lhs >> rhs element-wise. Maps to the TSHR hardware intrinsic.

    Args:
        lhs: Left-hand side tile (TileType)
        rhs: Right-hand side tile (TileType)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise right shift
    """
    actual_span = _get_span_or_capture(span)
    return _ir_core.create_op_call("block.shr", [lhs, rhs], {}, actual_span)


def shrs(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise right shift of tile and scalar.

    Computes lhs >> rhs element-wise. Maps to the TSHRS hardware intrinsic.

    Note:
        The scalar shift amount must be zero or positive; negative values are
        not supported by the hardware and will be rejected by codegen.

    Args:
        lhs: Tile (TileType)
        rhs: Scalar shift amount (int/float/Expr with ScalarType); must be >= 0
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise right shift with scalar
    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
    return _ir_core.create_op_call("block.shrs", [lhs, rhs_expr], {}, actual_span)


def and_(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise AND of two tiles.

    Computes lhs & rhs element-wise. Maps to the TAND hardware intrinsic.

    Args:
        lhs: Left-hand side tile (TileType)
        rhs: Right-hand side tile (TileType)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise AND
    """
    actual_span = _get_span_or_capture(span)
    return _ir_core.create_op_call("block.and", [lhs, rhs], {}, actual_span)


def ands(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise AND of tile and scalar.

    Computes lhs & rhs element-wise. Maps to the TANDS hardware intrinsic.

    Args:
        lhs: Tile (TileType)
        rhs: Scalar (int/float/Expr with ScalarType)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise AND with scalar
    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
    return _ir_core.create_op_call("block.ands", [lhs, rhs_expr], {}, actual_span)


def or_(lhs: Expr, rhs: Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise OR of two tiles.

    Computes lhs | rhs element-wise. Maps to the TOR hardware intrinsic.

    Args:
        lhs: Left-hand side tile (TileType)
        rhs: Right-hand side tile (TileType)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise OR
    """
    actual_span = _get_span_or_capture(span)
    return _ir_core.create_op_call("block.or", [lhs, rhs], {}, actual_span)


def ors(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise OR of tile and scalar.

    Computes lhs | rhs element-wise. Maps to the TORS hardware intrinsic.

    Args:
        lhs: Tile (TileType)
        rhs: Scalar (int/float/Expr with ScalarType)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise OR with scalar
    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
    return _ir_core.create_op_call("block.ors", [lhs, rhs_expr], {}, actual_span)


def xor(lhs: Expr, rhs: Expr, tmp: Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise XOR of two tiles.

    Computes lhs ^ rhs element-wise. Maps to the TXOR hardware intrinsic.

    Args:
        lhs: Left-hand side tile (TileType)
        rhs: Right-hand side tile (TileType)
        tmp: Temporary tile (TileType) required by the hardware
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise XOR
    """
    actual_span = _get_span_or_capture(span)
    return _ir_core.create_op_call("block.xor", [lhs, rhs, tmp], {}, actual_span)


def xors(lhs: Expr, rhs: int | float | Expr, tmp: Expr, span: Span | None = None) -> Call:
    """Element-wise bitwise XOR of tile and scalar.

    Computes lhs ^ rhs element-wise. Maps to the TXORS hardware intrinsic.

    Args:
        lhs: Tile (TileType)
        rhs: Scalar (int/float/Expr with ScalarType)
        tmp: Temporary tile (TileType) required by the hardware
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression for element-wise bitwise XOR with scalar
    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
    return _ir_core.create_op_call("block.xors", [lhs, rhs_expr, tmp], {}, actual_span)


⚠️ Potential issue | 🟠 Major

Bitwise/shift scalar normalization is using FP32 and can break operand typing.

Line 444, Line 487, Line 526, Line 565, and Line 606 normalize scalar literals with int_dtype=DataType.FP32. For shls/shrs/ands/ors/xors, this can coerce integer literals into FP32 constants before emitting integer bitwise/shift ops.

🐛 Proposed fix pattern
+def _normalize_int_scalar(rhs: int | float | Expr, span: Span, arg_name: str = "rhs") -> Expr:
+    if isinstance(rhs, float):
+        raise TypeError(f"{arg_name} must be an integer scalar for bitwise/shift ops")
+    if isinstance(rhs, Expr):
+        return rhs
+    return _normalize_expr(rhs, span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
@@
 def shls(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
@@
-    rhs_expr = (
-        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
-        if not isinstance(rhs, Expr)
-        else rhs
-    )
+    rhs_expr = _normalize_int_scalar(rhs, actual_span)
@@
 def shrs(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
@@
-    rhs_expr = (
-        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
-        if not isinstance(rhs, Expr)
-        else rhs
-    )
+    rhs_expr = _normalize_int_scalar(rhs, actual_span)
@@
 def ands(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
@@
-    rhs_expr = (
-        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
-        if not isinstance(rhs, Expr)
-        else rhs
-    )
+    rhs_expr = _normalize_int_scalar(rhs, actual_span)
@@
 def ors(lhs: Expr, rhs: int | float | Expr, span: Span | None = None) -> Call:
@@
-    rhs_expr = (
-        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
-        if not isinstance(rhs, Expr)
-        else rhs
-    )
+    rhs_expr = _normalize_int_scalar(rhs, actual_span)
@@
 def xors(lhs: Expr, rhs: int | float | Expr, tmp: Expr, span: Span | None = None) -> Call:
@@
-    rhs_expr = (
-        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
-        if not isinstance(rhs, Expr)
-        else rhs
-    )
+    rhs_expr = _normalize_int_scalar(rhs, actual_span)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/ir/op/block_ops.py` around lines 425 - 611, The scalar
normalization in shls, shrs, ands, ors, xors is incorrectly passing
int_dtype=DataType.FP32 which will coerce integer literals to FP32; update the
_normalize_expr calls inside the functions shls, shrs, ands, ors, and xors to
use int_dtype=DataType.INT32 (keep float_dtype=DataType.FP32 or remove
float_dtype if not needed) so scalar literals remain integer-typed for
bitwise/shift intrinsics.
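The hazard the comment describes can be seen in plain Python, used here only as an analogy (pypto's own constant handling may behave differently): shift and bitwise operators are defined for integers, so a scalar that has been silently promoted to a float type stops being a valid operand.

```python
# Plain-Python analogy for the FP32-coercion hazard: bitwise/shift
# operators require integer operands, so a scalar that has been
# promoted to float is no longer usable as a shift amount.
lhs, shift = 0b1010, 2

print(lhs << shift)          # integer shift works: 40

try:
    lhs << float(shift)      # a float shift amount is rejected
except TypeError as e:
    print("float shift rejected:", e)
```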

Comment on lines 214 to 313
REGISTER_OP("block.shl")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise left shift of two tiles with broadcasting")
    .add_argument("lhs", "Left-hand side tile (TileType)")
    .add_argument("rhs", "Right-hand side tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.shl");
    });

REGISTER_OP("block.shls")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise left shift of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.shls");
    });

REGISTER_OP("block.shr")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise right shift of two tiles with broadcasting")
    .add_argument("lhs", "Left-hand side tile (TileType)")
    .add_argument("rhs", "Right-hand side tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.shr");
    });

REGISTER_OP("block.shrs")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise right shift of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.shrs");
    });

REGISTER_OP("block.maxs")
    .set_op_category("BlockOp")
    .set_description("Element-wise maximum of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.maxs");
    });

REGISTER_OP("block.mins")
    .set_op_category("BlockOp")
    .set_description("Element-wise minimum of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.mins");
    });

REGISTER_OP("block.and")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise AND of two tiles with broadcasting")
    .add_argument("lhs", "Left-hand side tile (TileType)")
    .add_argument("rhs", "Right-hand side tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.and");
    });

REGISTER_OP("block.ands")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise AND of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.ands");
    });

REGISTER_OP("block.or")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise OR of two tiles with broadcasting")
    .add_argument("lhs", "Left-hand side tile (TileType)")
    .add_argument("rhs", "Right-hand side tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpElementwiseBinaryType(args, kwargs, "block.or");
    });

REGISTER_OP("block.ors")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise OR of tile and scalar")
    .add_argument("lhs", "Tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.ors");
    });


⚠️ Potential issue | 🟠 Major

Bitwise op registrations should enforce bitwise-compatible dtypes.

block.shl/shls/shr/shrs/and/ands/or/ors currently use generic arithmetic deduction, so float operands are accepted. That allows invalid bitwise IR combinations at type-inference time.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ir/op/block_ops/elementwise.cpp` around lines 214 - 313, The bitwise ops
(REGISTER_OP entries for "block.shl", "block.shls", "block.shr", "block.shrs",
"block.and", "block.ands", "block.or", "block.ors") currently call generic
deduction helpers and thus allow float dtypes; change their .f_deduce_type
lambdas to call bitwise-specific deduction helpers (e.g., replace calls to
DeduceBlockOpElementwiseBinaryType and DeduceBlockOpScalarBinaryType with names
like DeduceBlockOpElementwiseBitwiseBinaryType and
DeduceBlockOpScalarBitwiseBinaryType or an equivalent helper that enforces
integer/bitwise-compatible dtypes), and implement/update those helper functions
to validate/require integer or bitwise-compatible ScalarType/TileType during
type inference. Ensure the op-name string passed (e.g., "block.shl",
"block.ands", etc.) is preserved when invoking the new deducer so error messages
remain clear.
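The validation the prompt asks for is a small guard in front of the existing deduction. A minimal Python sketch of the idea (the names `deduce_bitwise_binary` and `INT_DTYPES` are hypothetical and not the repo's actual helpers, which are C++):

```python
# Hypothetical sketch of a bitwise-aware type deducer: identical in
# spirit to the generic elementwise deduction, except float dtypes
# are rejected before promotion.
INT_DTYPES = {"int8", "int16", "int32", "uint8"}   # assumed dtype names

def deduce_bitwise_binary(op_name: str, lhs_dtype: str, rhs_dtype: str) -> str:
    for dtype in (lhs_dtype, rhs_dtype):
        if dtype not in INT_DTYPES:
            raise TypeError(f"{op_name} requires integer dtypes, got {dtype}")
    # promote by bit width: pick the wider of the two integer dtypes
    return max(lhs_dtype, rhs_dtype,
               key=lambda d: int("".join(filter(str.isdigit, d))))

print(deduce_bitwise_binary("block.and", "int8", "int32"))   # int32
```

The same check, expressed in C++ inside the `.f_deduce_type` lambda, would preserve the generic shape/broadcast logic and only add the dtype gate.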

Comment on lines +173 to +180
REGISTER_OP("block.not")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise NOT of a tile")
    .add_argument("tile", "Input tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockUnaryType(args, kwargs, "block.not");
    });

⚠️ Potential issue | 🟠 Major

Restrict block.not to bitwise-compatible dtypes.

Line 179 routes block.not through DeduceBlockUnaryType, so float tiles are accepted. For a bitwise NOT op, this should reject non-bitwise dtypes during type deduction.

💡 Proposed fix
+TypePtr DeduceBlockBitwiseUnaryType(const std::vector<ExprPtr>& args,
+                                    const std::vector<std::pair<std::string, std::any>>& kwargs,
+                                    const std::string& op_name) {
+  CHECK(args.size() == 1) << "The operator " << op_name << " requires exactly 1 argument, but got "
+                          << args.size();
+  auto tile_type = As<TileType>(args[0]->GetType());
+  CHECK(tile_type) << "The operator " << op_name << " requires argument to be a TileType, but got "
+                   << args[0]->GetType()->TypeName();
+  CHECK(!tile_type->dtype_.IsFloat())
+      << "The operator " << op_name << " requires bitwise-compatible dtype, but got "
+      << tile_type->dtype_.ToString();
+  return std::make_shared<TileType>(tile_type->shape_, tile_type->dtype_);
+}
+
 REGISTER_OP("block.not")
@@
     .f_deduce_type([](const std::vector<ExprPtr>& args,
                       const std::vector<std::pair<std::string, std::any>>& kwargs) {
-      return DeduceBlockUnaryType(args, kwargs, "block.not");
+      return DeduceBlockBitwiseUnaryType(args, kwargs, "block.not");
     });
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
REGISTER_OP("block.not")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise NOT of a tile")
    .add_argument("tile", "Input tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockUnaryType(args, kwargs, "block.not");
    });
TypePtr DeduceBlockBitwiseUnaryType(const std::vector<ExprPtr>& args,
                                    const std::vector<std::pair<std::string, std::any>>& kwargs,
                                    const std::string& op_name) {
  CHECK(args.size() == 1) << "The operator " << op_name << " requires exactly 1 argument, but got "
                          << args.size();
  auto tile_type = As<TileType>(args[0]->GetType());
  CHECK(tile_type) << "The operator " << op_name << " requires argument to be a TileType, but got "
                   << args[0]->GetType()->TypeName();
  CHECK(!tile_type->dtype_.IsFloat())
      << "The operator " << op_name << " requires bitwise-compatible dtype, but got "
      << tile_type->dtype_.ToString();
  return std::make_shared<TileType>(tile_type->shape_, tile_type->dtype_);
}

REGISTER_OP("block.not")
    .set_op_category("BlockOp")
    .set_description("Element-wise bitwise NOT of a tile")
    .add_argument("tile", "Input tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockBitwiseUnaryType(args, kwargs, "block.not");
    });
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ir/op/block_ops/unary.cpp` around lines 173 - 180, The registered op
block.not currently uses DeduceBlockUnaryType which allows floating-point tiles;
change the type-deduction lambda to enforce bitwise-compatible dtypes by
checking the input tile's dtype and rejecting non-integer/non-bool types (or
call a helper like DeduceBlockUnaryBitwiseType if available). Specifically,
inside the REGISTER_OP("block.not") f_deduce_type lambda, inspect the first
arg's dtype via the existing type-deduction utilities and return an
error/invalid type when the dtype is floating-point, ensuring only integer or
boolean tile dtypes are accepted; keep the operator name "block.not" and reuse
existing deduction patterns from other bitwise ops for consistency.
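NumPy enforces the same restriction for its bitwise NOT, which makes a useful reference behavior: `np.invert` computes two's-complement NOT on integers and raises on float arrays rather than reinterpreting their bits.

```python
import numpy as np

print(np.invert(np.array([5], dtype=np.int32)))   # two's-complement NOT: [-6]

try:
    np.invert(np.array([5.0], dtype=np.float32))  # floats are rejected outright
except TypeError as e:
    print("rejected:", e)
```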

@doraemonmj
Contributor

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces several new block-level operations, including rem, rems, bitwise operations (and_, ands, or_, ors, xor, xors, shl, shls, shr, shrs), maxs, mins, prelu, not_, ternary arithmetic operations (addc, subc, addsc, subsc), lrelu, selection operations (sel, sels), matrix-vector multiplication (gemv, gemv_acc, gemv_bias), and row_expand. The code review highlights several issues in the type deduction logic for ternary operations (addc, subc, addsc, subsc, prelu), matmul_bias, and rems. Specifically, the type deduction for ternary operations and matmul_bias incorrectly handles the third argument's data type promotion and shape broadcasting. The rems function also incorrectly normalizes integer rhs values to DataType.FP32 instead of DataType.INT32. Additionally, the addsc and subsc operations use an incorrectly named type deduction function (DeduceBlockOpXorScalarType) that enforces an integer type for the scalar argument and does not correctly deduce the output type.

Comment on lines +355 to +358
  auto result_dtype = PromoteDataTypes(tile_type1->dtype_, tile_type2->dtype_);
  CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";
  auto broadcast_result = BroadcastShapes(tile_type1->shape_, tile_type2->shape_);
  CHECK(broadcast_result.success) << "The operator " << op_name << " requires compatible shapes";
Contributor

critical

The DeduceBlockOpTernaryType function only promotes the data types and broadcasts the shapes of the first two arguments, ignoring the third argument. For operations like addc(lhs, rhs, rhs2), this is incorrect as the type of the third argument should also be part of the type promotion and shape broadcasting.

Comment on lines 442 to 462
REGISTER_OP("block.addsc")
    .set_op_category("BlockOp")
    .set_description("Element-wise addition of tile, scalar, and tile (lhs + scalar + rhs2)")
    .add_argument("lhs", "Left-hand side tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .add_argument("rhs2", "Third tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpXorScalarType(args, kwargs, "block.addsc");
    });

REGISTER_OP("block.subsc")
    .set_op_category("BlockOp")
    .set_description("Element-wise subtraction of tile, scalar, and tile (lhs - scalar - rhs2)")
    .add_argument("lhs", "Left-hand side tile (TileType)")
    .add_argument("rhs", "Scalar (ScalarType)")
    .add_argument("rhs2", "Third tile (TileType)")
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceBlockOpXorScalarType(args, kwargs, "block.subsc");
    });
Contributor

critical

The addsc and subsc operations are using DeduceBlockOpXorScalarType for type deduction. This is incorrect for several reasons:

  1. The function is poorly named for this purpose, leading to confusion.
  2. It enforces that the scalar argument is an integer (CHECK(scalar_type->dtype_.IsInt())), which is too restrictive for addsc/subsc which should support floating-point types.
  3. It does not correctly deduce the output type. It only considers the first argument's type (lhs) and ignores the types of the scalar and the third tile argument (rhs2). The output type should be a promotion of all three inputs.
    A new type deduction helper function should be created for (Tile, Scalar, Tile) operations like addsc and subsc.
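NumPy's promotion rules illustrate what "a promotion of all three inputs" means in practice; they are used here only as a reference model, not as the behavior of the repo's `PromoteDataTypes`:

```python
import numpy as np

# Pairwise promotion of only (lhs, rhs) loses the third operand's precision:
pairwise = np.result_type(np.float16, np.float16)
print(pairwise)                                       # float16

# Promoting across all three operands keeps the widest type:
full = np.result_type(np.float16, np.float16, np.float32)
print(full)                                           # float32
```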

    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
Contributor

high

The rems function normalizes integer rhs values to DataType.FP32. Remainder is typically an integer operation, and this implicit conversion of an integer to a float can lead to precision issues or unexpected behavior. It should be normalized to an integer type, such as DataType.INT32, to preserve the integer nature of the operation when an integer literal is provided.

Suggested change
_normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
_normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
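The precision concern is concrete: once an integer is coerced to a 64-bit float, large values round, and the remainder silently changes.

```python
n = 10**16 + 1          # odd, needs 54 mantissa bits: not exactly representable in float64

print(n % 3)            # exact integer remainder: 2
print(float(n) % 3.0)   # float(n) rounds to 1.0e16, so the remainder becomes 1.0
```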

Comment on lines 210 to 211
  auto result_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
  CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";
Contributor

high

The type deduction for matmul_bias only promotes the data types of lhs and rhs, ignoring the bias argument's type. For an operation like C = lhs @ rhs + bias, the final result type should be the promotion of all three inputs: promote(promote(lhs, rhs), bias).

Suggested change
  auto result_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
  CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";
  auto matmul_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
  CHECK(matmul_dtype) << "The operator " << op_name << " requires compatible data types for lhs and rhs";
  auto result_dtype = PromoteDataTypes(*matmul_dtype, bias_type->dtype_);
  CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types for matmul result and bias";
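The same effect is visible with NumPy as a reference model for `C = lhs @ rhs + bias` (illustrative only; the repo's promotion lattice may differ):

```python
import numpy as np

lhs = np.ones((2, 3), dtype=np.float16)
rhs = np.ones((3, 4), dtype=np.float16)
bias = np.ones((1, 4), dtype=np.float32)

# The matmul alone stays float16; adding the float32 bias widens the result,
# so the bias dtype must participate in the final promotion.
out = lhs @ rhs + bias
print(out.dtype)   # float32
```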

@doraemonmj
Contributor

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request introduces a comprehensive set of new block-level operations, including bitwise operations (AND, OR, XOR, SHL, SHR, NOT), arithmetic remainder (REM), activation functions (PReLU, LReLU), selection operations (SEL, SELS), and matrix multiplication variants (MATMUL_BIAS, GEMV). The changes span across the IR definition, Python DSL wrappers, and unit tests. The implementation appears robust, with appropriate type deduction logic in C++ and consistent Python bindings. The new test cases cover the basic functionality of each added operation, ensuring correct IR generation.

Comment on lines +401 to +404
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
Contributor

medium

The rems function normalizes integer scalar inputs to DataType.FP32 via _normalize_expr. This might lead to unexpected floating-point remainder behavior if the user intends integer remainder for integer inputs, especially since other bitwise scalar operations (shls, shrs, ands, ors, xors, addsc, subsc) normalize integer scalars to DataType.INT32. If the underlying TREMS hardware intrinsic supports both integer and floating-point remainder, or if integer remainder is expected for integer inputs, this implicit conversion could be problematic. Please clarify if TREMS is exclusively floating-point, or if integer remainder is also supported and intended for integer scalar inputs.

Suggested change
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs
    )
        _normalize_expr(rhs, actual_span, int_dtype=DataType.INT32, float_dtype=DataType.FP32)
        if not isinstance(rhs, Expr)
        else rhs

@doraemonmj
Contributor

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant number of new block-level operations, including bitwise, arithmetic, activation, and matrix multiplication variants. The implementation spans the Python DSL, IR bindings, and C++ operator registration, following the existing architecture. The changes are well-structured and include corresponding tests for the new operations. My review identified a high-severity issue in the type deduction logic for matmul_bias and gemv_bias which could lead to incorrect type promotion, and a minor style issue with a duplicated comment. Overall, this is a great addition to the operator library.

Comment on lines 210 to 214
  auto result_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
  CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";

  std::vector<ExprPtr> output_shape = {lhs_shape[0], rhs_shape[1]};
  return std::make_shared<TileType>(output_shape, *result_dtype);
Contributor

high

The type deduction for matmul_bias (and gemv_bias) doesn't seem to consider the bias tensor's data type for promotion. The operation is C = (lhs @ rhs) + bias, so the final result type should be the promotion of (lhs @ rhs) and bias. Currently, it only promotes lhs and rhs. This could lead to incorrect type inference and potential precision loss if the bias has a higher precision type.

Suggested change
  auto result_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
  CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";
  std::vector<ExprPtr> output_shape = {lhs_shape[0], rhs_shape[1]};
  return std::make_shared<TileType>(output_shape, *result_dtype);
  auto matmul_dtype = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
  CHECK(matmul_dtype) << "The operator " << op_name << " requires compatible lhs and rhs data types";
  auto result_dtype = PromoteDataTypes(*matmul_dtype, bias_type->dtype_);
  CHECK(result_dtype) << "The operator " << op_name << " requires a compatible bias data type";
  return std::make_shared<TileType>(std::vector<ExprPtr>{lhs_shape[0], rhs_shape[1]}, *result_dtype);

Comment on lines 142 to 144
// ============================================================================
// Registration Function for Block Row Broadcast Operations
// ============================================================================
Contributor

medium

This comment block appears to be a duplicate. Removing it will improve code clarity.

    """
    actual_span = _get_span_or_capture(span)
    rhs_expr = (
        _normalize_expr(rhs, actual_span, int_dtype=DataType.FP32, float_dtype=DataType.FP32)
Contributor

int_dtype should be INT32?

    "matmul",
    "matmul_acc",
    "matmul_bias",
    "gemv",
Contributor

What op is gemv?
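For context (answering from BLAS convention, not from this repo's docs): GEMV is the standard BLAS Level-2 "general matrix-vector multiply", y = alpha * A @ x + beta * y; the `gemv_acc`/`gemv_bias` variants here presumably correspond to its accumulate and bias-add forms, mirroring `matmul_acc`/`matmul_bias`.

```python
import numpy as np

# BLAS-style GEMV: y = alpha * A @ x + beta * y
def gemv(alpha, A, x, beta, y):
    return alpha * (A @ x) + beta * y

A = np.arange(6.0).reshape(2, 3)     # 2x3 matrix [[0,1,2],[3,4,5]]
x = np.array([1.0, 1.0, 1.0])        # length-3 vector
y = np.zeros(2)

print(gemv(1.0, A, x, 0.0, y))       # plain matrix-vector product: [ 3. 12.]
```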

    return Tile(expr=call_expr)


def and_(lhs: Tile, rhs: Tile) -> Tile:
Contributor

Why does `and` need the trailing underscore?

Contributor

It's to avoid clashing with the Python keyword.

      return DeduceBlockOpScalarBinaryType(args, kwargs, "block.mins");
    });

REGISTER_OP("block.and")
Contributor

Should logical and bitwise operations be restricted to non-float data types?

  auto result_dtype = PromoteDataTypes(tile_type1->dtype_, tile_type2->dtype_);
  CHECK(result_dtype) << "The operator " << op_name << " requires compatible data types";
  auto broadcast_result = BroadcastShapes(tile_type1->shape_, tile_type2->shape_);
  CHECK(broadcast_result.success) << "The operator " << op_name << " requires compatible shapes";
Contributor

How should broadcasting work for a three-operand op?
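One common answer, which the type deducer could mirror, is NumPy's rule: broadcasting is associative, so any number of shapes can be folded pairwise or resolved in a single step.

```python
import numpy as np

# NumPy resolves any number of shapes against each other in one step:
print(np.broadcast_shapes((4, 1), (1, 5), (4, 5)))   # (4, 5)

# which matches folding pairwise, left to right:
print(np.broadcast_shapes(np.broadcast_shapes((4, 1), (1, 5)), (4, 5)))  # (4, 5)
```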

}

// All three tiles are real inputs (addc, subc): promote dtype and broadcast shape across all three.
TypePtr DeduceBlockOpTriTileType(const std::vector<ExprPtr>& args,
Contributor

Isn't this the same as DeduceBlockOpTernaryType?

      return DeduceBlockMatMulAccType(args, kwargs, "block.matmul_acc");
    });

REGISTER_OP("block.matmul_bias")
Contributor

Does the PTO ISA have this instruction?
